import gym
from gym.wrappers import Monitor
# from option import Option
from env.mujoco_env.reacher_env import ReacherGymEnv
from env.mujoco_env.reacher_env import RMReacherGymEnv
from env.mujoco_env.reacher_env import ReacherGymEnvEval
import os
import torch
import gym
from monitor import Monitor
from option import *
import numpy as np
from pathlib import Path
import contextlib

def save_dataset(exp_name, method_name, task_name, epoch, results):
    """Write ``results`` to ``dataset/<exp_name>/<method_name>/<task_name>/<epoch>.npz``.

    The ``dataset`` folder is resolved relative to the directory containing
    this source file, and any missing folders along the way are created.
    ``results`` is passed to ``np.savez`` as a single positional argument.
    """
    out_dir = Path(__file__).parent.joinpath('dataset', exp_name, method_name, task_name)
    # Build the whole folder chain; an already existing folder is not an error.
    out_dir.mkdir(parents=True, exist_ok=True)
    np.savez(out_dir / f"{epoch}.npz", results)

def test_all_epochs(metapolicy_name, nF, task_name):
    """Evaluate a fixed schedule of saved checkpoints and store every result.

    Epochs 0, 10, ..., 980 plus the final epoch 999 are each evaluated with
    10 rollouts through ``test_epoch``; each returned result is written with
    ``save_dataset`` under the ``'satisfaction'`` experiment name.
    """
    rollouts_per_epoch = 10
    # Every 10th epoch up to 980, followed by the last epoch.
    checkpoint_epochs = list(range(0, 990, 10))
    checkpoint_epochs.append(999)

    for checkpoint in checkpoint_epochs:
        outcome = test_epoch(checkpoint, nF, task_name, metapolicy_name,
                             num_tests=rollouts_per_epoch)
        save_dataset('satisfaction', metapolicy_name, task_name, checkpoint, outcome)

def test_epoch(epoch, nF, task_name, metapolicy_name, num_tests=10):
    """Load the checkpoint saved at ``epoch`` and evaluate it with rollouts.

    The model file is
    ``$LOF_PKG_PATH/experiments/rm/<task_name>/pyt_save/model<epoch>.pt``;
    a ``KeyError`` is raised if ``LOF_PKG_PATH`` is not set in the environment.

    Returns:
        dict with keys ``'reward'``, ``'epoch'``, ``'success'`` and
        ``'last_state'``; all but ``'epoch'`` are the per-episode lists
        returned by ``run_rollout``.
    """
    checkpoint_file = 'model{}.pt'.format(epoch)
    option_load_path = os.path.join(
        os.environ['LOF_PKG_PATH'], 'experiments', 'rm', task_name,
        'pyt_save', checkpoint_file)

    # Anything written to stdout while the Option is constructed is discarded.
    with open(os.devnull, "w") as sink:
        with contextlib.redirect_stdout(sink):
            option = Option(option_load_path)

    print("MODEL {} | TESTING EPOCH {}".format(metapolicy_name, epoch))
    rollout = run_rollout(nF, task_name, option, num_tests)
    episode_returns, last_states, solved = rollout

    return {'reward': episode_returns, 'epoch': epoch, 'success': solved, 'last_state': last_states}

###############
# Run Rollout #
###############
def run_rollout(nF, task_name, policy, num_episodes):
    """Roll out ``policy`` for ``num_episodes`` episodes in an ``RMReacherGymEnv``.

    On odd-numbered episodes (1, 3, ...) ``cancel=True`` is passed to
    ``env.step`` on the step taken right after the FSA state has moved from 0
    to a non-zero value; on every other step, and on all even-numbered
    episodes, ``cancel=False`` is passed.

    Args:
        nF: number of FSA states; ``nF - 1`` is treated as the goal state.
        task_name: task identifier forwarded to ``RMReacherGymEnv``.
        policy: object exposing ``get_action(obs_tensor)``.
        num_episodes: number of episodes to run.

    Returns:
        ``(rewards, final_fsa_states, successes)``: three lists with one
        entry per episode holding the summed reward, the last FSA state read
        from the observation, and whether the goal state was reached.
    """
    goal_state = nF - 1

    rewards = []
    final_fsa_states = []
    successes = []
    env = RMReacherGymEnv(nF=nF, task_name=task_name, training=False,
                          env_config={'headless': True, 'horizon': 800})

    # Close the environment even if a rollout raises, so a failed evaluation
    # does not leave the simulator open.
    try:
        for i in range(num_episodes):
            cancel_episode = i % 2 == 1

            task_done = False
            R = 0
            obs = env.reset()

            f = 0
            prev_f = f
            success = False

            while not task_done:
                # cancel only on ODD episodes, and only when the FSA
                # transitions from the initial state to the next state
                cancel = False
                if cancel_episode and prev_f == 0 and f != 0:
                    cancel = True
                    print("CANCELLED")

                a = policy.get_action(torch.from_numpy(obs).float())
                obs, reward, task_done, info = env.step(a, cancel=cancel)
                # Remember the state before this step so the next iteration
                # can detect a 0 -> non-zero transition.
                prev_f = f
                f = obs[0]  # the FSA state is read from the first observation entry
                R += reward

                if f == goal_state:
                    success = True
                    # NOTE(review): the loop still waits for env.step to
                    # report done; presumably this flag makes the next step
                    # do so - confirm against the env implementation.
                    env.set_task_done(True)

            rewards.append(R)
            final_fsa_states.append(f)
            successes.append(success)
            print(f"Episode {i} return: {R} | FSA: {f}")
    finally:
        env.close()

    return rewards, final_fsa_states, successes

#######
# Run #        
#######
def run_all_tests():
    """Run the full checkpoint sweep for the 'RM' method on every task.

    Each task is announced on stdout and then handed to ``test_all_epochs``
    together with its ``nF`` value.
    """
    # (nF, task name) pairs, evaluated in this order; trim the tuple to
    # evaluate a single task.
    benchmark = (
        (7, 'composite'),
        (5, 'sequential'),
        (5, 'IF'),
        (3, 'OR'),
    )
    method = 'RM'

    for nF, task_name in benchmark:
        print("TEST {} {}".format(method, task_name))
        test_all_epochs(method, nF, task_name)

# Script entry point: run the whole evaluation sweep when this file is
# executed directly; importing the module does not start it.
if __name__ == '__main__':
    run_all_tests()